import os
import logging
import time
import joblib
import pandas as pd
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import openpyxl
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.naive_bayes import GaussianNB
from sklearn.linear_model import SGDClassifier
from sklearn.svm import SVC
from sklearn.model_selection import GridSearchCV, StratifiedKFold
from sklearn.metrics import (
    classification_report, confusion_matrix, roc_auc_score,
    roc_curve, precision_recall_curve, average_precision_score,
    balanced_accuracy_score, ConfusionMatrixDisplay
)

def pca_transform(X, pca, scaler, n_components):
    """Project X with an already-fitted scaler and PCA, keeping the leading components.

    Args:
        X: Feature matrix accepted by ``scaler.transform``.
        pca: Fitted object exposing ``transform``.
        scaler: Fitted object exposing ``transform``.
        n_components: Number of leading principal components to keep.

    Returns:
        DataFrame with columns ``PC1`` .. ``PC<n_components>`` and a default
        integer index.
    """
    projected = pca.transform(scaler.transform(X))
    column_names = [f'PC{idx}' for idx in range(1, n_components + 1)]
    return pd.DataFrame(projected[:, :n_components], columns=column_names)

def pca_fit_transform(X, variance_threshold=0.90):
    """Standardize X, fit a full PCA and keep the leading components.

    The number of retained components is the smallest count whose cumulative
    explained-variance ratio reaches ``variance_threshold``, capped at the
    number of components the PCA actually produced.

    Args:
        X: DataFrame of numeric features; its ``columns`` label the loadings.
        variance_threshold: Target cumulative explained-variance ratio.

    Returns:
        Tuple ``(df_pca, cols, pca, scaler, loadings)`` where ``df_pca`` holds
        the retained component scores, ``cols`` their names (``PC1``, ...),
        ``pca`` and ``scaler`` the fitted transformers, and ``loadings`` the
        retained rows of ``pca.components_`` indexed by component name.
    """
    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X)

    pca = PCA()
    X_pca = pca.fit_transform(X_scaled)

    explained = np.cumsum(np.asarray(pca.explained_variance_ratio_, dtype=float))
    # Clamp to the number of fitted components: with a threshold >= 1.0 (or
    # when round-off leaves the final cumulative value just below the
    # threshold) the raw count would exceed the available columns and the
    # DataFrame construction below would fail on a column-count mismatch.
    n_components = min(int((explained < variance_threshold).sum()) + 1, explained.size)

    X_pca_cut = X_pca[:, :n_components]
    cols = [f'PC{i+1}' for i in range(n_components)]
    df_pca = pd.DataFrame(X_pca_cut, columns=cols)

    loadings = pd.DataFrame(pca.components_[:n_components],
                            columns=X.columns,
                            index=cols)

    return df_pca, cols, pca, scaler, loadings

def plot_pca_variance(pca_model, save_path, variance_threshold=0.90):
    """Save a plot of cumulative explained variance versus number of components.

    Args:
        pca_model: Fitted object exposing ``explained_variance_ratio_``.
        save_path: Destination path of the image file.
        variance_threshold: Ratio drawn as a horizontal reference line.
    """
    explained_variance = np.cumsum(pca_model.explained_variance_ratio_)
    # The x axis is a component *count*, so it must start at 1: the first
    # cumulative value is the variance explained by one component, not zero.
    component_counts = np.arange(1, len(explained_variance) + 1)

    plt.figure(figsize=(8, 6))
    plt.plot(component_counts, explained_variance, marker='o', label='Varianza Cumulativa')
    plt.axhline(y=variance_threshold, color='red', linestyle='--', label=f'Soglia {variance_threshold:.0%}')
    plt.xlabel('Numero Componenti')
    plt.ylabel('Varianza Cumulativa')
    plt.title('PCA - Varianza Cumulativa Spiegata')
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()

def plot_confusion_matrix(matrix, phase, classifier_name, output_dir):
    """Render a confusion matrix and save it as a PNG inside ``output_dir``.

    Args:
        matrix: Confusion matrix to display (labelled with classes 0 and 1).
        phase: Phase name used in the title and the file name.
        classifier_name: Classifier name used in the title and the file name.
        output_dir: Directory receiving the image.
    """
    image_path = os.path.join(output_dir, f"{classifier_name}_{phase}_confusion_matrix.png")

    display = ConfusionMatrixDisplay(confusion_matrix=matrix, display_labels=[0, 1])
    display.plot(cmap=plt.cm.Blues)
    plt.title(f"{classifier_name} - {phase} Confusion Matrix")
    plt.savefig(image_path)
    plt.close()

def plot_roc_curve(y_true, y_prob, phase, classifier_name, output_dir):
    """Draw the ROC curve with its AUC in the legend and save it as a PNG.

    Args:
        y_true: Ground-truth labels.
        y_prob: Scores or probabilities for the positive class.
        phase: Phase name used in the title and the file name.
        classifier_name: Classifier name used in the title and the file name.
        output_dir: Directory receiving the image.
    """
    false_pos_rate, true_pos_rate, _thresholds = roc_curve(y_true, y_prob)
    area = roc_auc_score(y_true, y_prob)
    image_path = os.path.join(output_dir, f"{classifier_name}_{phase}_roc_curve.png")

    plt.figure()
    plt.plot(false_pos_rate, true_pos_rate, label=f"AUC = {area:.2f}")
    # Diagonal reference line from (0, 0) to (1, 1).
    plt.plot([0, 1], [0, 1], 'k--')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(f"{classifier_name} - {phase} ROC Curve")
    plt.legend(loc="lower right")
    plt.savefig(image_path)
    plt.close()

def plot_precision_recall(y_true, y_prob, phase, classifier_name, output_dir):
    """Draw the precision-recall curve with its average precision and save it.

    Args:
        y_true: Ground-truth labels.
        y_prob: Scores or probabilities for the positive class.
        phase: Phase name used in the title and the file name.
        classifier_name: Classifier name used in the title and the file name.
        output_dir: Directory receiving the image.
    """
    precision_values, recall_values, _thresholds = precision_recall_curve(y_true, y_prob)
    avg_precision = average_precision_score(y_true, y_prob)
    image_path = os.path.join(
        output_dir, f"{classifier_name}_{phase}_precision_recall_curve.png")

    plt.figure()
    plt.plot(recall_values, precision_values, label=f"AP = {avg_precision:.2f}")
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.title(f"{classifier_name} - {phase} Precision-Recall Curve")
    plt.legend(loc="upper right")
    plt.savefig(image_path)
    plt.close()

def plot_and_save_results(results_dict, output_dir, feature_selection_method='PCA'):
    """Flatten per-classifier metrics, save them to Excel and plot comparisons.

    Writes ``summary_auc_completo.xlsx`` plus one bar chart per metric
    ('Accuracy' and 'ROC AUC') comparing train and validation results.

    Args:
        results_dict: Mapping ``classifier name -> {'train': metrics,
            'validation': metrics}``; missing sets or metrics become empty
            cells.
        output_dir: Directory receiving the workbook and the images; created
            if it does not exist.
        feature_selection_method: Label stored in the table and shown in the
            plot titles.
    """
    # The directory must exist before the summary workbook is written.
    os.makedirs(output_dir, exist_ok=True)

    # Output column name -> key in the per-set metrics dictionary.
    metric_columns = {
        'ROC AUC': 'roc_auc',
        'Accuracy': 'accuracy',
        'bal accuracy': 'bal_accuracy',
        'F1-score 0': 'F1_score_0',
        'F1-score 1': 'F1_score_1',
        'Precision 0': 'precision_0',
        'Precision 1': 'precision_1',
        'Recall 0': 'recall_0',
        'Recall 1': 'recall_1',
    }

    # Turn the nested dictionary into one row per (classifier, set).
    rows = []
    for clf_name, clf_results in results_dict.items():
        for set_name in ['train', 'validation']:
            metrics = clf_results.get(set_name, {})
            row = {
                'Feature Selection': feature_selection_method,
                'Classifier': clf_name,
                'Set': set_name,
            }
            row.update({column: metrics.get(key) for column, key in metric_columns.items()})
            rows.append(row)

    results_df_flat = pd.DataFrame(rows)

    # Save the aggregated table.
    summary_path = os.path.join(output_dir, "summary_auc_completo.xlsx")
    results_df_flat.to_excel(summary_path, index=False)
    print(f"\n Report globale salvato in: {summary_path}")

    # Passed to barplot directly instead of sns.set_palette so the global
    # seaborn/matplotlib colour cycle is left untouched for later figures.
    palette = ["#1f77b4", "#ff7f0e"]

    for metric_col in ['Accuracy', 'ROC AUC']:
        # Metrics that could not be computed are stored as the string 'N/A';
        # coerce them to NaN so the bar plot skips them instead of receiving
        # mixed string/float values.
        plot_df = results_df_flat.assign(
            **{metric_col: pd.to_numeric(results_df_flat[metric_col], errors='coerce')})

        plt.figure(figsize=(12, 6))
        sns.barplot(data=plot_df, x='Classifier', y=metric_col,
                    hue='Set', hue_order=["train", "validation"], palette=palette)
        plt.title(f'Confronto tra classificatori - {feature_selection_method} ({metric_col})')
        plt.ylabel(metric_col)
        plt.ylim(0, 1)
        plt.xticks(rotation=45)
        plt.legend(title='Set')
        plt.grid(axis='y', linestyle='--', alpha=0.7)
        plt.tight_layout()

        # Save the figure.
        filename = f"{metric_col.replace('-', '').lower()}_comparison.png"
        filepath = os.path.join(output_dir, filename)
        plt.savefig(filepath)
        plt.close()

        print(f"Grafico salvato: {filepath}")

def evaluate_model(model, X, y, phase, output_dir, classifier_name):
    """Evaluate a fitted binary classifier and persist reports and plots.

    Writes the classification report (xlsx), the confusion matrix (csv with
    true classes as rows; xlsx transposed, predicted classes as rows) and the
    confusion-matrix image into ``output_dir``. ROC and precision-recall
    images are written only when a ROC AUC could be computed.

    Args:
        model: Fitted estimator exposing ``predict`` and optionally
            ``predict_proba`` or ``decision_function``.
        X: Feature matrix.
        y: Ground-truth labels for classes 0 and 1 (integer or float).
        phase: Phase name used in file names and messages.
        output_dir: Existing directory receiving the output files.
        classifier_name: Classifier name used in file names and messages.

    Returns:
        Dictionary with the full report, the confusion matrix as nested
        lists, ``roc_auc`` (float, or the string ``'N/A'`` when it cannot be
        computed), accuracy, balanced accuracy and per-class precision,
        recall and F1.

    Raises:
        KeyError: If class 0 or class 1 is missing from the report.
    """
    def _class_metrics(label):
        # classification_report keys each class by str(label): '1.0' for
        # float labels, '1' for integer labels. Accept both spellings.
        for key in (f"{float(label)}", f"{int(label)}"):
            if key in report:
                return report[key]
        raise KeyError(
            f"Class {label} not found in classification report (keys: {list(report)})")

    y_pred = model.predict(X)

    report = classification_report(y, y_pred, output_dict=True)
    matrix = confusion_matrix(y, y_pred)
    accuracy = report['accuracy']
    class0 = _class_metrics(0)
    class1 = _class_metrics(1)
    bal_accuracy = balanced_accuracy_score(y, y_pred)

    # Continuous scores for ranking metrics: probabilities when available,
    # otherwise the decision-function margin.
    if hasattr(model, "predict_proba"):
        y_score = model.predict_proba(X)[:, 1]
    elif hasattr(model, "decision_function"):
        y_score = model.decision_function(X)
    else:
        y_score = None

    # Best-effort AUC: it needs scores and both classes present in y.
    try:
        auc = roc_auc_score(y, y_score) if y_score is not None and len(np.unique(y)) > 1 else 'N/A'
    except Exception as e:
        print(f"Errore nel calcolo AUC per {classifier_name} in fase {phase}: {e}")
        auc = 'N/A'

    # Save textual results.
    pd.DataFrame(report).transpose().to_excel(
        os.path.join(output_dir, f"{classifier_name}_{phase}_report.xlsx"))
    cm_df = pd.DataFrame(matrix, columns=['Pred 0', 'Pred 1'], index=['True 0', 'True 1'])
    cm_df.to_csv(os.path.join(output_dir, f"{classifier_name}_{phase}_confusion_matrix.csv"))
    cm_df.transpose().to_excel(os.path.join(output_dir, f"{classifier_name}_{phase}_confusion_matrix.xlsx"))

    # Visualizations. The curve plots recompute the AUC internally, so they
    # are skipped when it could not be computed above (same guard as test()).
    plot_confusion_matrix(matrix, phase, classifier_name, output_dir)
    if y_score is not None and isinstance(auc, float):
        plot_roc_curve(y, y_score, phase, classifier_name, output_dir)
        plot_precision_recall(y, y_score, phase, classifier_name, output_dir)

    return {
        'classification_report': report,
        'confusion_matrix': matrix.tolist(),
        'roc_auc': auc,
        'accuracy': accuracy,
        'F1_score_1': class1['f1-score'],
        'F1_score_0': class0['f1-score'],
        'precision_0': class0['precision'],
        'precision_1': class1['precision'],
        'recall_0': class0['recall'],
        'recall_1': class1['recall'],
        'bal_accuracy': bal_accuracy,
    }
    
def test(model, clf_name, feature_selection_name, X_test, y_test, output_dir):
    """Evaluate the selected model on the test set and persist the results.

    Everything is written to ``<output_dir>/risultati_test`` (created if
    missing): the classification report, the transposed confusion matrix
    (predicted classes as rows), the plots and a one-row summary workbook
    ``risultati_del_test.xlsx``.

    Args:
        model: Fitted estimator exposing ``predict`` and optionally
            ``predict_proba`` or ``decision_function``.
        clf_name: Classifier name used in file names and messages.
        feature_selection_name: Feature-selection label; only used in the
            progress message.
        X_test: Test features (DataFrame or array-like).
        y_test: Ground-truth labels for classes 0 and 1 (integer or float).
        output_dir: Parent directory of the ``risultati_test`` folder.

    Returns:
        Dictionary with the full report (``test_report``), the confusion
        matrix as nested lists, ``roc_auc`` (float or ``'N/A'``), accuracy,
        balanced accuracy and per-class precision, recall and F1.

    Raises:
        KeyError: If class 0 or class 1 is missing from the report.
    """
    def _class_metrics(label):
        # classification_report keys each class by str(label): '1.0' for
        # float labels, '1' for integer labels. Accept both spellings.
        for key in (f"{float(label)}", f"{int(label)}"):
            if key in test_report:
                return test_report[key]
        raise KeyError(
            f"Class {label} not found in classification report (keys: {list(test_report)})")

    output_dir = os.path.join(output_dir, 'risultati_test')
    os.makedirs(output_dir, exist_ok=True)

    print(f"\n[TESTING] Classificatore: {clf_name} | FS: {feature_selection_name}")
    # Works for DataFrames and for inputs that are already arrays.
    X_test = np.asarray(X_test)

    # Prediction and label-based metrics.
    y_pred = model.predict(X_test)
    test_report = classification_report(y_test, y_pred, output_dict=True)
    matrix = confusion_matrix(y_test, y_pred)
    accuracy = test_report['accuracy']
    class0 = _class_metrics(0)
    class1 = _class_metrics(1)
    bal_accuracy = balanced_accuracy_score(y_test, y_pred)

    # Continuous scores: probabilities when available, otherwise the
    # decision-function margin.
    if hasattr(model, "predict_proba"):
        y_score = model.predict_proba(X_test)[:, 1]
    elif hasattr(model, "decision_function"):
        y_score = model.decision_function(X_test)
    else:
        y_score = None

    # Best-effort AUC: it needs scores and both classes present in y_test.
    try:
        auc = roc_auc_score(y_test, y_score) if y_score is not None and len(np.unique(y_test)) > 1 else 'N/A'
    except Exception as e:
        print(f"Errore nel calcolo AUC per {clf_name} in fase TEST: {e}")
        auc = 'N/A'

    # Save the report and the confusion matrix to Excel.
    pd.DataFrame(test_report).transpose().to_excel(
        os.path.join(output_dir, f"{clf_name}_TEST_report.xlsx"))
    cm_df = pd.DataFrame(matrix, columns=['Pred 0', 'Pred 1'], index=['True 0', 'True 1'])
    cm_df.transpose().to_excel(os.path.join(output_dir, f"{clf_name}_TEST_confusion_matrix.xlsx"))

    # Save the plots; curve plots only when the AUC is available.
    plot_confusion_matrix(matrix, 'TEST', clf_name, output_dir)
    if y_score is not None and isinstance(auc, float):
        plot_roc_curve(y_test, y_score, 'TEST', clf_name, output_dir)
        plot_precision_recall(y_test, y_score, 'TEST', clf_name, output_dir)

    risultati_test = {
        'test_report': test_report,
        'confusion_matrix': matrix.tolist(),
        'roc_auc': auc,
        'accuracy': accuracy,
        'bal_accuracy': bal_accuracy,
        'F1_score_1': class1['f1-score'],
        'F1_score_0': class0['f1-score'],
        'precision_0': class0['precision'],
        'precision_1': class1['precision'],
        'recall_0': class0['recall'],
        'recall_1': class1['recall'],
    }

    # One-row summary holding only the scalar metrics.
    scalar_metrics = {k: v for k, v in risultati_test.items()
                      if k not in ('confusion_matrix', 'test_report')}
    path_finale = os.path.join(output_dir, "risultati_del_test.xlsx")
    pd.DataFrame([scalar_metrics]).to_excel(path_finale, index=False)

    return risultati_test
   
def esegui_classificatori_lineari(X_train, y_train, X_val, y_val, output_dir):
    """Train, tune and validate the linear classifiers, tracking the best one.

    For each classifier (Gaussian Naive Bayes, linear SVM, SGD) a
    sub-directory of ``output_dir`` receives the best hyper-parameters (when
    a grid search is run), the training report, the pickled model and the
    validation outputs produced by ``evaluate_model``. A combined
    ``summary_auc_report.xlsx`` is written to ``output_dir``.

    Args:
        X_train: Training features (DataFrame or array-like).
        y_train: Training labels for classes 0 and 1 (integer or float).
        X_val: Validation features (DataFrame or array-like).
        y_val: Validation labels.
        output_dir: Existing directory receiving all outputs.

    Returns:
        Tuple ``(risultati, best_clf_for_fs)``. ``risultati`` maps each
        classifier name to ``{'train': ..., 'validation': ...}`` metrics.
        ``best_clf_for_fs`` holds the model with the highest validation ROC
        AUC under the keys ``model``, ``best auc`` and ``classif_name``;
        ``model`` is ``None`` if no classifier produced a numeric AUC.

    Raises:
        KeyError: If class 0 or class 1 is missing from a training report.
    """
    def _class_metrics(report, label):
        # classification_report keys each class by str(label): '1.0' for
        # float labels, '1' for integer labels. Accept both spellings.
        for key in (f"{float(label)}", f"{int(label)}"):
            if key in report:
                return report[key]
        raise KeyError(
            f"Class {label} not found in classification report (keys: {list(report)})")

    # Works for DataFrames and for inputs that are already arrays.
    X_train = np.asarray(X_train)
    X_val = np.asarray(X_val)

    classifiers = {
        "Naive Bayes": GaussianNB(),
        "SVM Linear": SVC(kernel='linear', C=1, gamma=1, probability=True),
        "SGD": SGDClassifier()
    }

    # An empty grid means "fit with default parameters, no search".
    # NOTE(review): gamma is not used by SVC's linear kernel, so searching
    # over it likely only doubles the fitting cost; kept so the saved
    # best-params file is unchanged - confirm before removing.
    param_grid = {
        "Naive Bayes": {},
        "SVM Linear": {
            "C": [0.1, 1, 10, 100],
            "gamma": ['scale', 'auto']
        },
        "SGD": {
            "alpha": [1e-5, 1e-4, 1e-3],
            "loss": ['hinge', 'log_loss'],
            "penalty": ['l2', 'l1', 'elasticnet']
        }
    }

    risultati = {}
    report_aggregato = []
    best_auc = -np.inf
    best_mod = None
    best_clf_name = ""

    for name, clf in classifiers.items():

        print(f"\n Training {name}...")
        classifier_dir = os.path.join(output_dir, name)
        os.makedirs(classifier_dir, exist_ok=True)

        start_time = time.time()
        if param_grid[name]:
            print(f" - Eseguo GridSearchCV per {name}")
            cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
            grid = GridSearchCV(clf, param_grid[name], cv=cv, scoring='roc_auc', n_jobs=-1)
            grid.fit(X_train, y_train)
            best_model = grid.best_estimator_

            # Persist the selected hyper-parameters.
            with open(os.path.join(classifier_dir, f"{name}_best_params.txt"), "w",
                      encoding="utf-8") as f:
                f.write(str(grid.best_params_))
        else:
            best_model = clf.fit(X_train, y_train)

        tempo = time.time() - start_time
        print(f"Tempo di Tuning: {tempo:.2f} secondi")

        # Training-set metrics.
        y_train_pred = best_model.predict(X_train)

        # Continuous scores: probabilities when available, otherwise the
        # decision-function margin.
        if hasattr(best_model, "predict_proba"):
            y_score = best_model.predict_proba(X_train)[:, 1]
        elif hasattr(best_model, "decision_function"):
            y_score = best_model.decision_function(X_train)
        else:
            y_score = None

        # Best-effort AUC: it needs scores and both classes in y_train.
        try:
            auc = roc_auc_score(y_train, y_score) if y_score is not None and len(np.unique(y_train)) > 1 else 'N/A'
        except Exception as e:
            print(f"Errore nel calcolo AUC per {name} in fase Training: {e}")
            auc = 'N/A'

        bal_accuracy_t = balanced_accuracy_score(y_train, y_train_pred)
        report_train = classification_report(y_train, y_train_pred, output_dict=True)
        matrix = confusion_matrix(y_train, y_train_pred)
        class0_t = _class_metrics(report_train, 0)
        class1_t = _class_metrics(report_train, 1)

        pd.DataFrame(report_train).transpose().to_excel(
            os.path.join(classifier_dir, f"{name}_training_report.xlsx"))
        train_results = {
            'classification_report': report_train,
            'confusion_matrix': matrix.tolist(),
            'roc_auc': auc,
            'accuracy': report_train['accuracy'],
            'F1_score_1': class1_t['f1-score'],
            'F1_score_0': class0_t['f1-score'],
            'precision_0': class0_t['precision'],
            'precision_1': class1_t['precision'],
            'recall_0': class0_t['recall'],
            'recall_1': class1_t['recall'],
            'bal_accuracy': bal_accuracy_t
        }

        joblib.dump(best_model, os.path.join(classifier_dir, f"{name}_model.pkl"))
        val_results = evaluate_model(best_model, X_val, y_val, 'validation', classifier_dir, name)

        # The validation AUC is the string 'N/A' when it cannot be computed;
        # comparing that with a float raises TypeError, so only numeric
        # values take part in the best-model selection.
        val_auc = val_results['roc_auc']
        if isinstance(val_auc, float) and val_auc > best_auc:
            best_auc = val_auc
            best_mod = best_model
            best_clf_name = name

        risultati[name] = {
            'train': train_results,
            'validation': val_results,
        }

        # Row of the global summary workbook.
        report_aggregato.append({
            'Classifier': name,
            'Train ROC AUC': train_results['roc_auc'],
            'Train Accuracy': train_results['accuracy'],
            'Train Precision_0': train_results['precision_0'],
            'Train Precision_1': train_results['precision_1'],
            'Train Recall_0': train_results['recall_0'],
            'Train Recall_1': train_results['recall_1'],
            'Train F1 score 0': train_results['F1_score_0'],
            'Train F1 score 1': train_results['F1_score_1'],
            'Val ROC AUC': val_results['roc_auc'],
            'Val Accuracy': val_results['accuracy'],
            'Val Precision_0': val_results['precision_0'],
            'Val Precision_1': val_results['precision_1'],
            'Val Recall_0': val_results['recall_0'],
            'Val Recall_1': val_results['recall_1'],
            'Val F1 score 0': val_results['F1_score_0'],
            'Val F1 score 1': val_results['F1_score_1'],
        })

    best_clf_for_fs = {
        "model": best_mod,
        "best auc": best_auc,
        "classif_name": best_clf_name
    }

    pd.DataFrame(report_aggregato).to_excel(os.path.join(output_dir, "summary_auc_report.xlsx"), index=False)

    return risultati, best_clf_for_fs

# Main entry point
def PCA_metodo_e_train(training_set_path, validation_set_path, test_set_path,
                       output_dir, variance_threshold=0.90):
    """Run the PCA pipeline: reduce features, train/validate classifiers, test the best.

    The PCA (preceded by standardization) is fitted on the training set only
    and then applied to the validation and test sets. Projected datasets,
    loadings, fitted transformers, plots and classifier outputs are all
    written to ``output_dir``.

    Args:
        training_set_path: CSV path of the training set (first column is the
            index; must contain a ``Label`` column).
        validation_set_path: CSV path of the validation set, same layout.
        test_set_path: CSV path of the test set, same layout.
        output_dir: Directory receiving all outputs; created if missing.
        variance_threshold: Cumulative explained-variance ratio used to pick
            the number of principal components.

    Returns:
        None. If classifier training fails the error is logged and the
        function returns before the test and summary steps.
    """
    # Every step below writes into output_dir, so make sure it exists.
    os.makedirs(output_dir, exist_ok=True)

    training_set = pd.read_csv(training_set_path, index_col=0)
    validation_set = pd.read_csv(validation_set_path, index_col=0)
    test_set = pd.read_csv(test_set_path, index_col=0)

    # Split features and target; only numeric feature columns are kept.
    X_train = training_set.drop(columns=['Label']).select_dtypes(include=[np.number])
    y_train = training_set['Label']

    X_val = validation_set.drop(columns=['Label']).select_dtypes(include=[np.number])
    y_val = validation_set['Label']

    X_test = test_set.drop(columns=['Label']).select_dtypes(include=[np.number])
    y_test = test_set['Label']

    print("\nEseguo PCA sul training set...")
    df_pca_train, cols, pca_model, scaler, loadings = pca_fit_transform(X_train, variance_threshold=variance_threshold)

    print("Applico PCA su validation e test set...")
    df_pca_val = pca_transform(X_val, pca_model, scaler, len(cols))
    df_pca_test = pca_transform(X_test, pca_model, scaler, len(cols))

    # Attach the labels positionally: the projected frames carry a fresh
    # integer index, so aligning on the original index would be wrong.
    df_pca_train['Label'] = y_train.values
    df_pca_val['Label'] = y_val.values
    df_pca_test['Label'] = y_test.values

    # Persist projected datasets, loadings and fitted transformers.
    df_pca_train.to_excel(os.path.join(output_dir, "train_PCA.xlsx"), index=False)
    df_pca_val.to_excel(os.path.join(output_dir, "validation_PCA.xlsx"), index=False)
    df_pca_test.to_excel(os.path.join(output_dir, "test_PCA.xlsx"), index=False)
    loadings.to_excel(os.path.join(output_dir, "loadings_PCA.xlsx"))
    joblib.dump(pca_model, os.path.join(output_dir, "pca_model.pkl"))
    joblib.dump(scaler, os.path.join(output_dir, "scaler.pkl"))

    # Top 5 features per component, ranked by absolute loading.
    top_features_dict = {
        pc: loadings.loc[pc].abs().sort_values(ascending=False).head(5).index.tolist()
        for pc in cols
    }

    # After the transpose, columns are the components and rows the rank
    # positions (0 = largest absolute loading).
    df_top = pd.DataFrame.from_dict(top_features_dict, orient='index').transpose()
    df_top.to_excel(os.path.join(output_dir, "top5_features_per_PC.xlsx"), index_label="PC")

    # Explained-variance plot.
    plot_pca_variance(pca_model, os.path.join(output_dir, 'pca_varianza.png'), variance_threshold)

    try:
        risultati, best_clf_for_fs = esegui_classificatori_lineari(
            df_pca_train.drop(columns=['Label']),
            y_train,
            df_pca_val.drop(columns=['Label']),
            y_val,
            output_dir
        )
    except Exception:
        logging.exception("Errore nel processo di training-validation-test del dataset PCA")
        # Nothing to test or summarize without training results; continuing
        # would only fail on the unbound result variables.
        return

    if best_clf_for_fs.get("model") is not None:
        test(
            model=best_clf_for_fs["model"],
            clf_name=best_clf_for_fs["classif_name"],
            feature_selection_name='PCA',
            X_test=df_pca_test.drop(columns=['Label']),
            y_test=y_test,
            output_dir=output_dir
        )
    else:
        logging.warning("Nessun classificatore selezionato sul validation set: fase di test saltata")

    plot_and_save_results(risultati, output_dir)

    print(f'\nMetodo PCA applicato! Trovi i risultati in: {output_dir}')